import argparse
import pickle

# Presumably provides sweep_contiguous_edits_multiple, fn_colors_random_slow,
# fn_colors_fixed and fn_pattern_using_prev (used below, not defined here).
# Imported after the stdlib so that any `pickle` it re-exports still takes precedence.
from utils import *

# Command-line interface. Every argument is required; the help strings describe
# how each value is used by the sweep below.
parser = argparse.ArgumentParser(
    description=(
        "Sweep edits.width for several watermark color patterns and pickle "
        "one result file per pattern."
    ),
)
parser.add_argument(
    '--model', type=str, required=True,
    help='model name passed as model.name, e.g. "facebook/opt-1.3b"; '
         'its name component is also used in the output filename',
)
parser.add_argument(
    '--delta', type=float, required=True,
    help='value for watermark.delta; also recorded in the output filename',
)
parser.add_argument(
    '--batch_idx', type=int, required=True,
    help='batch index passed as first_batch_idx; also the output filename suffix',
)
parser.add_argument(
    '--edits_count', type=int, required=True,
    help='number of edit widths to sweep: edits.width takes values 0 .. edits_count-1',
)
args = parser.parse_args()

# Model tag used in output filenames. Take the LAST "/"-separated component:
# for hub ids like "meta-llama/Llama-2-7b-hf" this equals the old `[1]`, but it
# no longer raises IndexError for names without "/" (e.g. "gpt2") and no longer
# picks a leading directory for local paths. Loop-invariant, so computed once.
model_tag = args.model.split("/")[-1]

# One entry per experiment:
#   (pattern name, color pattern, color names, color-assignment function, color-mode tag)
# In every entry the pattern's values are < len(color names), so they appear to
# index into the color tuple -- NOTE(review): confirm against utils.
PATTERN_SPECS = (
    ("ACADBCBD", (0, 2, 0, 3, 1, 2, 1, 3), ("green", "red", "blue", "gray"), fn_colors_random_slow, "rand"),
    ("AB", (0, 1), ("green", "red"), fn_colors_random_slow, "rand"),
    ("AA", (0, 0), ("green", "red"), fn_colors_random_slow, "rand"),
    ("AA (fixed)", (0, 0), ("green", "red"), fn_colors_fixed, "fixed"),
    ("ACBC", (0, 2, 1, 2), ("green", "red", "blue"), fn_colors_random_slow, "rand"),
)

for pt_name, pt, pt_col, fn_cols, cols_name in PATTERN_SPECS:
    # The pattern length is reused as the detection / edit-detection window.
    w = len(pt)
    output_filename = f'15_{model_tag}_{cols_name}{pt_name}_delta={args.delta}_w={w}_{args.batch_idx}.pkl'
    print(args, output_filename)

    # Run one batch for each edit width in 0 .. edits_count-1 (swept over "edits.width").
    data = sweep_contiguous_edits_multiple(
        {
            "config": {
                "log": True,
                "verbose": False,
                "data.file_name": "01_wikitext2.json",
                "tokenizer.max_length": 16,
                "tokenizer.truncation": True,
                "tokenizer.truncation_side": "right",
                "tokenizer.padding_side": "left",
                "model.name": args.model,
                "model.token_count": 64,
                "model.do_sample": True,
                "model.num_beams": 4,
                "model.output_scores": False,
                "watermark.delta": args.delta,
                "watermark.colors": pt_col,
                "watermark.color_pattern": pt,
                "watermark.fn_colors": fn_cols,
                "watermark.fn_pattern": fn_pattern_using_prev,
                "edits.count": 1,
                # NOTE(review): w+7 is a magic offset; its meaning is defined in utils -- confirm.
                "edits.buffer": w + 7,
                "detection.window": w,
                "edit_detection.window": w,
                "edit_detection.tolerance": (-2, +2),
                "edit_detection.target_t1": 0.1,
            },
        },
        config_key="edits.width", config_vals=list(range(args.edits_count)),
        # Optional toggle kept from the original experiments:
        # edit_types=("INSERT",),
        first_batch_idx=args.batch_idx, batch_count=1, batch_size=128,
    )
    # One pickle per pattern; an existing file with the same name is overwritten.
    with open(output_filename, 'wb') as file:
        pickle.dump(data, file)

"""
# NO SAMPLING
for model in "meta-llama/Llama-2-7b-hf"; do for delta in 5.75; do python 15_exp_edit_detection_vs_width.py --model="$model" --delta=$delta --edits_count=8 --batch_idx=1; done; done
for model in "facebook/opt-1.3b"; do for delta in 4.75; do python 15_exp_edit_detection_vs_width.py --model="$model" --delta=$delta --edits_count=8 --batch_idx=1; done; done
# SAMPLE NUM_BEAMS=4
for model in "meta-llama/Llama-2-7b-hf"; do for delta in 4.6 0; do python 15_exp_edit_detection_vs_width.py --model="$model" --delta=$delta --edits_count=4 --batch_idx=1; done; done
for model in "facebook/opt-1.3b"; do for delta in 3.8 0; do python 15_exp_edit_detection_vs_width.py --model="$model" --delta=$delta --edits_count=4 --batch_idx=1; done; done

# DELTA -> PPL mapping
for model in "meta-llama/Llama-2-7b-hf" "facebook/opt-1.3b"; do for delta in 0 3.8 4.0 4.2 4.4 4.6 4.8 5.0 5.2 5.4 5.6 5.8 6.0 6.2 6.4 6.6 6.8 7.0 7.2 7.4 7.6 7.8 8.0 8.2 8.4 8.6 8.8 9.0 9.2;\
    do python 15_exp_edit_detection_vs_width.py --model="$model" --delta=$delta --edits_count=1 --batch_idx=1; done; done

# EVERYTHING
for model in "meta-llama/Llama-2-7b-hf" "facebook/opt-1.3b"; do for delta in 4.6 5.0 5.4 5.8 6.2 6.6 7.0 7.4;\
    do python 15_exp_edit_detection_vs_width.py --model="$model" --delta=$delta --edits_count=8 --batch_idx=1; done; done


# Unigram
for model in "meta-llama/Llama-2-7b-hf" "facebook/opt-1.3b"; do for delta in 4.6;\
    do python 15_exp_edit_detection_vs_width.py --model="$model" --delta=$delta --edits_count=8 --batch_idx=1; done; done
for model in "meta-llama/Llama-2-7b-hf" "facebook/opt-1.3b"; do for delta in 5.8;\
    do python 15_exp_edit_detection_vs_width.py --model="$model" --delta=$delta --edits_count=8 --batch_idx=1; done; done


# 1,000 5.8
for idx in 1; do for model in "meta-llama/Llama-2-7b-hf" "facebook/opt-1.3b"; do for delta in 3.8 4.8 5.8 6.8 7.8;\
    do python 15_exp_edit_detection_vs_width.py --model="$model" --delta=$delta --edits_count=8 --batch_idx=$idx; done; done; done
"""